{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "# Solving Cart-Pole with SVP\n",
    "\n",
    "This notebook shows how to use the proposed structured value-based planning (SVP) approach to generate the state-action $Q$-value function for the classic cart-pole problem. The correctness of the solution is verified through trajectory simulations and a comparison of the final metrics."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "## Problem definition\n",
    "\n",
    "Here we focus on the cart-pole problem, which has a higher-dimensional state space (4 dimensions).\n",
    "\n",
    "The cart-pole problem consists of balancing a pole in the upright position while it is attached to a cart moving on a frictionless track. The cart is controlled by a bounded force of at most $10\\,\\mathrm{N}$, which can be applied toward either the left or the right. The goal is to keep the pole at the upright equilibrium position.\n",
    "The physical dynamics of the system are described by the angle and angular speed of the pole and the position and speed of the cart, i.e., $(\\theta, \\dot{\\theta}, x, \\dot{x})$. Denoting by $\\tau$ the time interval between decisions and by $u$ the force input on the cart, the dynamics can be written as\n",
    "\n",
    "\\begin{align}\n",
    "    & \\ddot{\\theta} := \\frac{g\\sin{\\theta} - \\frac{u+ml \\dot{\\theta}^2 \\sin{\\theta}}{m_c+m} \\cos{\\theta}}{l\\left( \\frac{4}{3} - \\frac{m\\cos^2{\\theta}}{m_c+m} \\right)},\\\\\n",
    "    & \\ddot{x} := \\frac{u + ml\\left( \\dot{\\theta}^2 \\sin{\\theta} - \\ddot{\\theta}\\cos{\\theta}\\right) }{m_c+m},\\\\\n",
    "    & \\theta := \\theta + \\dot{\\theta}~\\tau,\\\\\n",
    "    & \\dot{\\theta} := \\dot{\\theta} + \\ddot{\\theta}~\\tau,\\\\\n",
    "    & x := x + \\dot{x}~\\tau,\\\\\n",
    "    & \\dot{x} := \\dot{x} + \\ddot{x}~\\tau,\n",
    "\\end{align}\n",
    "\n",
    "where $g = 9.8\\,\\mathrm{m/s^2}$ is the gravitational acceleration, $m_c = 1\\,\\mathrm{kg}$ is the mass of the cart, $m = 0.1\\,\\mathrm{kg}$ is the mass of the pole, and $l = 0.5\\,\\mathrm{m}$ is half of the pole length.\n",
    "\n",
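    "As a minimal sketch, one step of these dynamics could be written in Julia as follows. The names `cartpole_step`, `G`, `M_CART`, `M_POLE`, `L_HALF` and `TAU` are hypothetical and are not taken from the `CartPole` module used later in this notebook; `TAU = 0.1` is the value used in the simulation. The sketch assumes the explicit Euler ordering of the equations above, where the angle and position are advanced with the old velocities.\n",
    "\n",
    "```julia\n",
    "# Hypothetical constants, using the values quoted in this section.\n",
    "const G = 9.8        # gravitational acceleration (m/s^2)\n",
    "const M_CART = 1.0   # cart mass m_c (kg)\n",
    "const M_POLE = 0.1   # pole mass m (kg)\n",
    "const L_HALF = 0.5   # half of the pole length l (m)\n",
    "const TAU = 0.1      # time interval between decisions\n",
    "\n",
    "# One Euler step for state s = [theta, thetadot, x, xdot] and force u.\n",
    "function cartpole_step(s, u)\n",
    "    theta, thetadot, x, xdot = s\n",
    "    mtot = M_CART + M_POLE\n",
    "    tmp = (u + M_POLE * L_HALF * thetadot^2 * sin(theta)) / mtot\n",
    "    thetaacc = (G * sin(theta) - tmp * cos(theta)) /\n",
    "               (L_HALF * (4 / 3 - M_POLE * cos(theta)^2 / mtot))\n",
    "    xacc = (u + M_POLE * L_HALF * (thetadot^2 * sin(theta) - thetaacc * cos(theta))) / mtot\n",
    "    return [theta + thetadot * TAU,\n",
    "            thetadot + thetaacc * TAU,\n",
    "            x + xdot * TAU,\n",
    "            xdot + xacc * TAU]\n",
    "end\n",
    "\n",
    "# With zero force, the upright rest state maps to itself.\n",
    "cartpole_step([0.0, 0.0, 0.0, 0.0], 0.0)\n",
    "```\n",
    "\n",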
    "The reward function favors keeping the pole upright, i.e., within $\\left|\\theta\\right| \\le \\frac{12\\pi}{180}$ of the vertical position, and is expressed as\n",
    "\n",
    "\\begin{equation}\n",
    "    r(\\theta) = \\cos^4{(15 \\theta)}.\n",
    "\\end{equation}\n",
    "\n",
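    "A one-line Julia sketch of this reward, under the hypothetical name `reward_sketch` (the `reward` function passed to the MDP later in this notebook may be implemented differently):\n",
    "\n",
    "```julia\n",
    "# Reward as a function of the pole angle theta (in radians).\n",
    "reward_sketch(theta) = cos(15 * theta)^4\n",
    "\n",
    "reward_sketch(0.0)  # 1.0 at the upright position\n",
    "```\n",
    "\n",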
    "In the simulation, the state space is $[-\\frac{\\pi}{2}, \\frac{\\pi}{2}]$ for $\\theta$, $[-3,3]$ for $\\dot{\\theta}$, $[-2.4,2.4]$ for $x$, and $[-3.5,3.5]$ for $\\dot{x}$. We limit the input force to $[-10,10]$ and set $\\tau=0.1$.\n",
    "We discretize each dimension of the state space into 10 values and the action space into 1000 values, so the $Q$-value function is a matrix of dimension $10000\\times 1000$ ($10^4$ states by $1000$ actions)."
   ]
  },
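  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "subslide"
    }
   },
   "source": [
    "To make the size of the $Q$ matrix concrete, the discretization can be sketched as below. The grid names are hypothetical, and evenly spaced grids that include both endpoints are an assumption; the `state_space()` and `action_space()` functions used later may build their grids differently.\n",
    "\n",
    "```julia\n",
    "# Hypothetical grids; the bounds and counts are those quoted in the problem definition.\n",
    "n_per_dim = 10\n",
    "theta_grid    = range(-pi / 2, stop = pi / 2, length = n_per_dim)\n",
    "thetadot_grid = range(-3.0, stop = 3.0, length = n_per_dim)\n",
    "x_grid        = range(-2.4, stop = 2.4, length = n_per_dim)\n",
    "xdot_grid     = range(-3.5, stop = 3.5, length = n_per_dim)\n",
    "action_grid   = range(-10.0, stop = 10.0, length = 1000)\n",
    "\n",
    "# Rows of Q index joint states, columns index actions.\n",
    "n_states = length(theta_grid) * length(thetadot_grid) * length(x_grid) * length(xdot_grid)\n",
    "n_actions = length(action_grid)\n",
    "(n_states, n_actions)  # (10000, 1000)\n",
    "```"
   ]
  },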
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "## Structured Value-based Planning (SVP)\n",
    "\n",
    "The proposed structured value-based planning (SVP) approach is based on the $Q$-value iteration. At the $t$-th iteration, instead of a full pass over all state-action pairs:\n",
    "- SVP first randomly selects a subset $\\Omega$ of the state-action pairs. In particular, each state-action pair in $\\mathcal{S}\\times\\mathcal{A}$ is observed (i.e., included in $\\Omega$) independently with probability $p$. \n",
    "- For each selected $(s,a)$, the intermediate $\\hat{Q}(s,a)$ is computed based on the $Q$-value iteration: \n",
    "    \\begin{equation*}\n",
    "    \\hat{Q}(s,a) \\leftarrow \\sum_{s'} P(s'|s,a) \\left( r(s,a) + \\gamma \\max_{a'} Q^{(t)}(s',a') \\right),\\quad\\forall\\:(s,a)\\in\\Omega.\n",
    "    \\end{equation*}\n",
    "- The current iteration then ends by reconstructing the full $Q$ matrix with matrix estimation, from the set of observations in $\\Omega$. That is, $Q^{(t+1)}=\\textrm{ME}\\big(\\{\\hat{Q}(s,a)\\}_{(s,a)\\in\\Omega}\\big).$\n",
    "\n",
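    "Overall, each iteration reduces the computation cost of the $Q$-value update by a fraction of roughly $1-p$, since only a fraction $p$ of the state-action pairs is evaluated."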
   ]
  },
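  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "subslide"
    }
   },
   "source": [
    "The three steps above can be sketched on a toy problem as follows. This is an illustration under stated assumptions, not the implementation behind `value_iteration` in this notebook: the names `me_sketch` and `svp_iteration` are hypothetical, transitions are taken to be deterministic (so the sum over $s'$ has a single term), and the ME step is replaced by a simple stand-in, namely iterative imputation with a rank-$k$ truncated SVD.\n",
    "\n",
    "```julia\n",
    "using LinearAlgebra\n",
    "\n",
    "# Stand-in for ME (an assumption for illustration): fill the unobserved\n",
    "# entries by repeatedly projecting onto rank-k matrices via a truncated SVD.\n",
    "# Requires k <= min(size(Qhat)...).\n",
    "function me_sketch(Qhat, mask; k = 2, iters = 50)\n",
    "    any(mask) || return zeros(size(Qhat))          # nothing observed, nothing to reconstruct\n",
    "    X = copy(Qhat)\n",
    "    X[.!mask] .= sum(Qhat[mask]) / count(mask)     # start from the mean of the observed entries\n",
    "    for _ in 1:iters\n",
    "        F = svd(X)\n",
    "        low = F.U[:, 1:k] * Diagonal(F.S[1:k]) * F.Vt[1:k, :]\n",
    "        X[.!mask] .= low[.!mask]                   # observed entries stay fixed\n",
    "    end\n",
    "    return X\n",
    "end\n",
    "\n",
    "# One SVP iteration with deterministic transitions next_state[s, a].\n",
    "function svp_iteration(Q, R, next_state, gamma, p)\n",
    "    mask = rand(size(Q)...) .< p                   # each (s, a) is observed independently with probability p\n",
    "    V = vec(maximum(Q, dims = 2))                  # max over a' of Q(s', a') for every s'\n",
    "    Qhat = zeros(size(Q))\n",
    "    for idx in findall(mask)\n",
    "        s, a = Tuple(idx)\n",
    "        Qhat[s, a] = R[s, a] + gamma * V[next_state[s, a]]\n",
    "    end\n",
    "    return me_sketch(Qhat, mask)                   # reconstruct the full matrix from the observations\n",
    "end\n",
    "\n",
    "# Toy usage: 3 states, 2 actions, half of the entries observed on average.\n",
    "R = [0.0 1.0; 1.0 0.0; 0.5 0.5]\n",
    "next_state = [1 2; 3 1; 2 3]\n",
    "Q1 = svp_iteration(zeros(3, 2), R, next_state, 0.9, 0.5)\n",
    "size(Q1)  # (3, 2)\n",
    "```\n",
    "\n",
    "Only the entries selected by `mask` go through the Bellman update; every other entry of the returned matrix comes from the reconstruction step."
   ]
  },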
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "subslide"
    }
   },
   "source": [
    "Through SVP, we can compute the final state-action $Q$-value function.\n",
    "To obtain the optimal policy for state $s$, we compute\n",
    "\n",
    "\\begin{align*}\n",
    "    \\pi^{\\star} \\left(s\\right) = \\mbox{argmax}_{a \\in \\mathcal{A}} Q^{\\star}\\left(s, a\\right).\n",
    "\\end{align*}"
   ]
  },
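  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "subslide"
    }
   },
   "source": [
    "As a small illustration of this argmax, the greedy policy can be read off a $Q$ matrix row by row. The helper name `greedy_policy` is hypothetical, and the layout (rows are states, columns are actions) is an assumption that follows the matrix dimensions given in the problem definition.\n",
    "\n",
    "```julia\n",
    "# Index of the greedy action for every state (rows are states, columns are actions).\n",
    "greedy_policy(Q) = [argmax(Q[s, :]) for s in 1:size(Q, 1)]\n",
    "\n",
    "Q = [0.1 0.7 0.2;\n",
    "     0.9 0.3 0.5]\n",
    "greedy_policy(Q)  # [2, 1]\n",
    "```"
   ]
  },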
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "### Generate state-action value function with SVP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true,
    "slideshow": {
     "slide_type": "fragment"
    }
   },
   "outputs": [],
   "source": [
    "push!(LOAD_PATH, \".\")\n",
    "using MDPs, CartPole, Printf, LinearAlgebra\n",
    "mdp = MDP(state_space(), action_space(), transition, reward)\n",
    "__init__()\n",
    "policy = value_iteration(mdp, true, \"../data/qcp_otf_0.2.csv\", true)\n",
    "print(\"\")  # suppress output"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "### Visualize policy as a heat map\n",
    "\n",
    "The cart-pole problem is 4-dimensional, so the full policy heat map would also have 4 dimensions, which is hard to visualize. Since the metric is the angle deviation, we plot only the first two dimensions (i.e., the `angle` and `angular speed`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "slideshow": {
     "slide_type": "fragment"
    }
   },
   "outputs": [],
   "source": [
    "viz_policy(mdp, policy, \"SVP policy (20% observed)\", true, \"cp/policy_cp_0.2.tex\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "### Verify correctness\n",
    "Simulate and visualize a trajectory from the initial state `[angle, angular_speed, position, x_speed]`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "slideshow": {
     "slide_type": "fragment"
    }
   },
   "outputs": [],
   "source": [
    "ss, as = simulate(mdp, policy, [-0.5, 0.0, 0.0, 0.0])\n",
    "viz_trajectory(ss, as, \"SVP trajectory (20%)\", \"SVP input (20%)\", true, \"cp/traj_cp_0.2.tex\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "average deviation: 10.076\n"
     ]
    }
   ],
   "source": [
    "nsim = 2000\n",
    "deviation = 0\n",
    "for sim = 1:nsim\n",
    "    state = [rand(THETAMIN:0.001 * (THETAMAX - THETAMIN):THETAMAX), \n",
    "             rand(THETADOTMIN:0.001 * (THETADOTMAX - THETADOTMIN):THETADOTMAX),\n",
    "             rand(XMIN:0.001 * (XMAX - XMIN):XMAX), \n",
    "             rand(XDOTMIN:0.001 * (XDOTMAX - XDOTMIN):XDOTMAX)]\n",
    "    traj, _ = simulate(mdp, policy, copy(state))\n",
    "    deviation += norm(traj[51:end, 1])  # L2 norm of the pole angle over rows 51:end of the trajectory\n",
    "end # for sim\n",
    "deviation /= nsim\n",
    "@printf(\"average deviation: %.3f\\n\", deviation)"
   ]
  }
 ],
 "metadata": {
  "celltoolbar": "Slideshow",
  "kernelspec": {
   "display_name": "Julia 0.7.0",
   "language": "julia",
   "name": "julia-0.7"
  },
  "language_info": {
   "file_extension": ".jl",
   "mimetype": "application/julia",
   "name": "julia",
   "version": "0.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
